{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
   "metadata": {},
   "source": [
    "# 🌀 RealNVP"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
   "metadata": {},
   "source": [
    "In this notebook, we'll walk through the steps required to train your own RealNVP network to learn the distribution of a demo dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2db9a506-bf8f-4a40-ab27-703b0c82371b",
   "metadata": {},
   "source": [
    "The code has been adapted from the excellent [RealNVP tutorial](https://keras.io/examples/generative/real_nvp) created by Mandolini Giorgio Maria, Sanna Daniele and Zannini Quirini Giorgio available on the Keras website."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn import datasets\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import (\n",
    "    layers,\n",
    "    models,\n",
    "    regularizers,\n",
    "    metrics,\n",
    "    optimizers,\n",
    "    callbacks,\n",
    ")\n",
    "import tensorflow_probability as tfp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
   "metadata": {},
   "source": [
    "## 0. Parameters <a name=\"parameters\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of units in each hidden Dense layer of a coupling network\n",
    "COUPLING_DIM = 256\n",
    "# Number of coupling networks stacked inside the RealNVP model\n",
    "COUPLING_LAYERS = 2\n",
    "# Number of features per data point (input width of each coupling network)\n",
    "INPUT_DIM = 2\n",
    "# L2 factor handed to the kernel regularizer of every Dense layer\n",
    "REGULARIZATION = 0.01\n",
    "# Batch size and epoch count passed to model.fit\n",
    "BATCH_SIZE = 256\n",
    "EPOCHS = 300"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a73e5a4-1638-411c-8d3c-29f823424458",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample the two-moons dataset and pass it through a Normalization layer\n",
    "# that has been adapted to the same data\n",
    "data = datasets.make_moons(30000, noise=0.05)[0].astype(\"float32\")\n",
    "norm = layers.Normalization()\n",
    "norm.adapt(data)\n",
    "normalized_data = norm(data)\n",
    "\n",
    "# Convert to NumPy once, then plot the first feature against the second\n",
    "moons_xy = normalized_data.numpy()\n",
    "plt.scatter(moons_xy[:, 0], moons_xy[:, 1], c=\"green\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 2. Build the RealNVP network <a name=\"build\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71a2a4a1-690e-4c94-b323-86f0e5b691d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Coupling(input_dim, coupling_dim, reg):\n",
    "    \"\"\"Build the network that outputs scale and translation for one coupling layer.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Number of features in the input vector.\n",
    "        coupling_dim: Number of units in each hidden Dense layer.\n",
    "        reg: L2 factor applied to the kernel of every Dense layer.\n",
    "\n",
    "    Returns:\n",
    "        A Keras model with one input and two outputs `[s, t]`, each\n",
    "        `input_dim` units wide: `s` ends in a tanh activation and `t` in a\n",
    "        linear one.\n",
    "    \"\"\"\n",
    "    # Pass the shape as a tuple; newer Keras releases reject a bare int here.\n",
    "    input_layer = layers.Input(shape=(input_dim,))\n",
    "\n",
    "    def build_branch(output_activation):\n",
    "        # Four ReLU hidden layers followed by an `input_dim`-unit output layer.\n",
    "        hidden = input_layer\n",
    "        for _ in range(4):\n",
    "            hidden = layers.Dense(\n",
    "                coupling_dim,\n",
    "                activation=\"relu\",\n",
    "                kernel_regularizer=regularizers.l2(reg),\n",
    "            )(hidden)\n",
    "        return layers.Dense(\n",
    "            input_dim,\n",
    "            activation=output_activation,\n",
    "            kernel_regularizer=regularizers.l2(reg),\n",
    "        )(hidden)\n",
    "\n",
    "    # Create the scale branch first, then the translation branch, so the\n",
    "    # layers are instantiated in the same order as before.\n",
    "    s_output = build_branch(\"tanh\")\n",
    "    t_output = build_branch(\"linear\")\n",
    "\n",
    "    return models.Model(inputs=input_layer, outputs=[s_output, t_output])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f4dcd44-4189-4f39-b262-7afedb00a5a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RealNVP(models.Model):\n",
    "    \"\"\"RealNVP flow made of affine coupling layers with alternating masks.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Number of features in each data point.\n",
    "        coupling_layers: Number of coupling networks in the flow.\n",
    "        coupling_dim: Hidden width forwarded to each coupling network.\n",
    "        regularization: L2 factor forwarded to each coupling network.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, input_dim, coupling_layers, coupling_dim, regularization\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.coupling_layers = coupling_layers\n",
    "        # Zero-mean, unit-scale latent distribution sized from `input_dim`\n",
    "        # rather than being fixed at two dimensions.\n",
    "        self.distribution = tfp.distributions.MultivariateNormalDiag(\n",
    "            loc=[0.0] * input_dim, scale_diag=[1.0] * input_dim\n",
    "        )\n",
    "        # One mask row per coupling layer, alternating [0, 1, ...] and\n",
    "        # [1, 0, ...]. Building a row for every layer keeps an odd\n",
    "        # `coupling_layers` from indexing past the end of the array in `call`.\n",
    "        first_mask = np.arange(input_dim) % 2\n",
    "        self.masks = np.array(\n",
    "            [\n",
    "                first_mask if i % 2 == 0 else 1 - first_mask\n",
    "                for i in range(coupling_layers)\n",
    "            ],\n",
    "            dtype=\"float32\",\n",
    "        )\n",
    "        self.loss_tracker = metrics.Mean(name=\"loss\")\n",
    "        self.layers_list = [\n",
    "            Coupling(input_dim, coupling_dim, regularization)\n",
    "            for i in range(coupling_layers)\n",
    "        ]\n",
    "\n",
    "    @property\n",
    "    def metrics(self):\n",
    "        \"\"\"Metrics reported by this model: only the running mean of the loss.\"\"\"\n",
    "        return [self.loss_tracker]\n",
    "\n",
    "    def call(self, x, training=True):\n",
    "        \"\"\"Push `x` through the coupling layers in one of two directions.\n",
    "\n",
    "        When `training` is truthy the layers are visited last-to-first and\n",
    "        each one computes `(x - t) * exp(-s)` on the unmasked features while\n",
    "        subtracting `sum(s)` from the running log-determinant. Otherwise the\n",
    "        layers are visited first-to-last, each one computes `x * exp(s) + t`,\n",
    "        and the returned log-determinant stays at zero.\n",
    "\n",
    "        Returns:\n",
    "            Tuple `(x, log_det_inv)` with the transformed input and the\n",
    "            accumulated log-determinant term.\n",
    "        \"\"\"\n",
    "        log_det_inv = 0\n",
    "        direction = 1\n",
    "        if training:\n",
    "            direction = -1\n",
    "        for i in range(self.coupling_layers)[::direction]:\n",
    "            x_masked = x * self.masks[i]\n",
    "            reversed_mask = 1 - self.masks[i]\n",
    "            s, t = self.layers_list[i](x_masked)\n",
    "            # Only the features hidden from the coupling network are transformed.\n",
    "            s *= reversed_mask\n",
    "            t *= reversed_mask\n",
    "            # gate is -1 when direction is -1 and 0 when direction is 1.\n",
    "            gate = (direction - 1) / 2\n",
    "            x = (\n",
    "                reversed_mask\n",
    "                * (x * tf.exp(direction * s) + direction * t * tf.exp(gate * s))\n",
    "                + x_masked\n",
    "            )\n",
    "            log_det_inv += gate * tf.reduce_sum(s, axis=1)\n",
    "        return x, log_det_inv\n",
    "\n",
    "    def log_loss(self, x):\n",
    "        \"\"\"Return the mean negative log-likelihood of `x` under the flow.\"\"\"\n",
    "        y, logdet = self(x)\n",
    "        log_likelihood = self.distribution.log_prob(y) + logdet\n",
    "        return -tf.reduce_mean(log_likelihood)\n",
    "\n",
    "    def train_step(self, data):\n",
    "        \"\"\"Run one optimisation step and return the running mean loss.\"\"\"\n",
    "        # NOTE(review): the loss here is only the negative log-likelihood; the\n",
    "        # kernel regularizers declared in Coupling are not added to it --\n",
    "        # confirm whether that is intended.\n",
    "        with tf.GradientTape() as tape:\n",
    "            loss = self.log_loss(data)\n",
    "        g = tape.gradient(loss, self.trainable_variables)\n",
    "        self.optimizer.apply_gradients(zip(g, self.trainable_variables))\n",
    "        self.loss_tracker.update_state(loss)\n",
    "        return {\"loss\": self.loss_tracker.result()}\n",
    "\n",
    "    def test_step(self, data):\n",
    "        \"\"\"Evaluate the loss on one batch and return the running mean loss.\"\"\"\n",
    "        loss = self.log_loss(data)\n",
    "        self.loss_tracker.update_state(loss)\n",
    "        return {\"loss\": self.loss_tracker.result()}\n",
    "\n",
    "\n",
    "# Instantiate the flow with the hyperparameters from the Parameters section\n",
    "model = RealNVP(\n",
    "    input_dim=INPUT_DIM,\n",
    "    coupling_layers=COUPLING_LAYERS,\n",
    "    coupling_dim=COUPLING_DIM,\n",
    "    regularization=REGULARIZATION,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35b14665-4359-447b-be58-3fd58ba69084",
   "metadata": {},
   "source": [
    "## 3. Train the RealNVP network <a name=\"train\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9ec362d-41fa-473a-ad56-ebeec6cfd3b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the model with an Adam optimizer; no loss is passed because\n",
    "# RealNVP computes its own loss inside train_step\n",
    "model.compile(optimizer=optimizers.Adam(learning_rate=0.0001))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c525e44b-b3bb-489c-9d35-fcfe3e714e6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TensorBoard callback writing its logs under ./logs\n",
    "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
    "\n",
    "\n",
    "class ImageGenerator(callbacks.Callback):\n",
    "    \"\"\"Callback that plots the flow in both directions every tenth epoch.\n",
    "\n",
    "    Reads the notebook-level `model` and `normalized_data` variables.\n",
    "\n",
    "    Args:\n",
    "        num_samples: Number of latent points sampled for the generated panels.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_samples):\n",
    "        # Let the Keras base class set up its own attributes first.\n",
    "        super().__init__()\n",
    "        self.num_samples = num_samples\n",
    "\n",
    "    def generate(self):\n",
    "        \"\"\"Map the data to latent space and fresh latent samples to data space.\n",
    "\n",
    "        Returns:\n",
    "            Tuple `(x, z, samples)`: the points produced from `samples`, the\n",
    "            result of passing `normalized_data` through the model, and the\n",
    "            points sampled from `model.distribution`.\n",
    "        \"\"\"\n",
    "        # From data to latent space.\n",
    "        z, _ = model(normalized_data)\n",
    "\n",
    "        # From latent space to data.\n",
    "        samples = model.distribution.sample(self.num_samples)\n",
    "        x, _ = model.predict(samples, verbose=0)\n",
    "\n",
    "        return x, z, samples\n",
    "\n",
    "    def display(self, x, z, samples, save_to=None):\n",
    "        \"\"\"Draw a 2x2 grid of scatter plots and optionally save it.\n",
    "\n",
    "        Args:\n",
    "            x: Points shown in the \"g(Z)\" panel.\n",
    "            z: Points shown in the \"f(X)\" panel.\n",
    "            samples: Points shown in the \"Latent space Z\" panel.\n",
    "            save_to: Optional file path. When given, the figure is saved\n",
    "                there; the parent directory is created if it is missing.\n",
    "        \"\"\"\n",
    "        f, axes = plt.subplots(2, 2)\n",
    "        f.set_size_inches(8, 5)\n",
    "\n",
    "        # (axis, points, color, title, x-label, y-label) for each panel.\n",
    "        panels = [\n",
    "            (axes[0, 0], normalized_data, \"r\", \"Data space X\", \"x_1\", \"x_2\"),\n",
    "            (axes[0, 1], z, \"r\", \"f(X)\", \"z_1\", \"z_2\"),\n",
    "            (axes[1, 0], samples, \"g\", \"Latent space Z\", \"z_1\", \"z_2\"),\n",
    "            (axes[1, 1], x, \"g\", \"g(Z)\", \"x_1\", \"x_2\"),\n",
    "        ]\n",
    "        for ax, points, color, title, xlabel, ylabel in panels:\n",
    "            ax.scatter(points[:, 0], points[:, 1], color=color, s=1)\n",
    "            ax.set(title=title, xlabel=xlabel, ylabel=ylabel)\n",
    "            ax.set_xlim([-2, 2])\n",
    "            ax.set_ylim([-2, 2])\n",
    "\n",
    "        plt.subplots_adjust(wspace=0.3, hspace=0.6)\n",
    "        if save_to:\n",
    "            # savefig does not create missing directories, so make sure the\n",
    "            # target directory exists before writing the file.\n",
    "            target_dir = os.path.dirname(save_to)\n",
    "            if target_dir:\n",
    "                os.makedirs(target_dir, exist_ok=True)\n",
    "            plt.savefig(save_to)\n",
    "            print(f\"\\nSaved to {save_to}\")\n",
    "\n",
    "        plt.show()\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        \"\"\"Generate and save a figure when `epoch` is a multiple of ten.\"\"\"\n",
    "        if epoch % 10 == 0:\n",
    "            x, z, samples = self.generate()\n",
    "            self.display(\n",
    "                x,\n",
    "                z,\n",
    "                samples,\n",
    "                save_to=\"./output/generated_img_%03d.png\" % (epoch),\n",
    "            )\n",
    "\n",
    "\n",
    "# Callback instance that samples 3000 latent points for each figure\n",
    "img_generator_callback = ImageGenerator(num_samples=3000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd6a5a71-eb55-4ec0-9c8c-cb11a382ff90",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train on the normalized data only; no targets are passed to fit\n",
    "model.fit(\n",
    "    normalized_data,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    epochs=EPOCHS,\n",
    "    callbacks=[tensorboard_callback, img_generator_callback],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb1f295f-ade0-4040-a6a5-a7b428b08ebc",
   "metadata": {},
   "source": [
    "## 4. Generate images <a name=\"generate\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8db3cfe3-339e-463d-8af5-fbd403385fca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the callback to compute generated points, latent codes and latent samples\n",
    "x, z, samples = img_generator_callback.generate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80087297-3f47-4e0c-ac89-8758d4386d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the four panels without saving (save_to is left at its default of None)\n",
    "img_generator_callback.display(x, z, samples)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
